{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.824683  , -0.5655952 , -0.41859964], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#Environment wrapper around Pendulum-v1 exposing the classic 4-tuple gym API\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    \"\"\"Pendulum-v1 wrapper with the classic gym interface.\n",
    "\n",
    "    reset() returns only the observation; step() returns\n",
    "    (state, reward, done, info), where done is terminated or truncated,\n",
    "    and is also forced to True once `max_steps` steps have been taken.\n",
    "\n",
    "    Args:\n",
    "        max_steps: episode length cap counted by this wrapper (default 200).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_steps=200):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.max_steps = max_steps\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self, **kwargs):\n",
    "        \"\"\"Reset the env and the step counter; return only the observation.\n",
    "\n",
    "        Keyword arguments (e.g. seed=..., options=...) are forwarded to the\n",
    "        wrapped env so the environment can be seeded for reproducibility.\n",
    "        \"\"\"\n",
    "        state, _ = self.env.reset(**kwargs)\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Advance one step and return (state, reward, done, info).\"\"\"\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        self.step_n += 1\n",
    "        # NOTE(review): time-limit truncation is reported as done exactly like a\n",
    "        # real termination; if `done` later stops value bootstrapping, truncated\n",
    "        # steps are treated as terminal -- confirm this is intended.\n",
    "        done = terminated or truncated or self.step_n >= self.max_steps\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlHUlEQVR4nO3dfXRU9YH/8c+dPAwhYSYkkIwpCdBCC5QHKyBM7R5tyRI1Wl2xrS4HWUt1pcGC9LiVXcWj2z1hdX9ttavYPW3V01ZR2oJKpRoDhHUJTwGUB0WsSqIwCQ/NTBLJ5GG+vz9cZh1FTeAm803yfp0z55h7v/PNd64wb2bmzoxjjDECAMBCnmQvAACAT0KkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWSlqkHnroIY0aNUqDBg3SjBkztH379mQtBQBgqaRE6qmnntLSpUt19913a9euXZoyZYpKSkrU0NCQjOUAACzlJOMDZmfMmKHp06frP//zPyVJsVhMhYWFuvXWW3XHHXf09nIAAJZK7e1f2NbWppqaGi1btiy+zePxqLi4WNXV1We8TjQaVTQajf8ci8V08uRJ5ebmynGcHl8zAMBdxhg1NTWpoKBAHs8nP6nX65E6fvy4Ojs7lZ+fn7A9Pz9fr7/++hmvU15ernvuuac3lgcA6EV1dXUaMWLEJ+7v9UidjWXLlmnp0qXxn8PhsIqKilRXVyefz5fElQEAzkYkElFhYaGGDBnyqeN6PVLDhg1TSkqK6uvrE7bX19crEAic8Tper1der/dj230+H5ECgD7ss16y6fWz+9LT0zV16lRVVlbGt8ViMVVWVioYDPb2cgAAFkvK031Lly7V/PnzNW3aNF144YX62c9+ppaWFt14443JWA4AwFJJidR3vvMdHTt2TMuXL1coFNL555+vP//5zx87mQIAMLAl5X1S5yoSicjv9yscDvOaFAD0QV29H+ez+wAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYq9uR2rx5s6688koVFBTIcRytXbs2Yb8xRsuXL9d5552njIwMFRcX69ChQwljTp48qblz58rn8yk7O1sLFixQc3PzOd0QAED/0+1ItbS0aMqUKXrooYfOuP++++7Tgw8+qEceeUTbtm1TZmamSkpK1NraGh8zd+5c7d+/XxUVFVq3bp02b96sm2+++exvBQCgfzLnQJJZs2ZN/OdYLGYCgYC5//7749saGxuN1+s1Tz75pDHGmAMHDhhJZseOHfEx69evN47jmPfee69LvzccDhtJJhwOn8vyAQBJ0tX7cVdfk3r77bcVCoVUXFwc3+b3+zVjxgxVV1dLkqqrq5Wdna1p06bFxxQXF8vj8Wjbtm1nnDcajSoSiSRcAAD9n6uRCoVCkqT8/PyE7fn5+fF9oVBIeXl5CftTU1OVk5MTH/NR5eXl8vv98UthYaGbywYAWKpPnN23bNkyhcPh+KWuri7ZSwIA9AJXIxUIBCRJ9fX1Cdvr6+vj+wKBgBoaGhL2d3R06OTJk/ExH+X1euXz+RIuAID+z9VIjR49WoFAQJWVlfFtkUhE27ZtUzAYlCQF
g0E1NjaqpqYmPmbDhg2KxWKaMWOGm8sBAPRxqd29QnNzs9588834z2+//bb27NmjnJwcFRUVacmSJfrxj3+ssWPHavTo0brrrrtUUFCgq6++WpI0fvx4XXrppbrpppv0yCOPqL29XYsWLdJ1112ngoIC124YAKAf6O5pgxs3bjSSPnaZP3++MeaD09Dvuusuk5+fb7xer5k1a5Y5ePBgwhwnTpww119/vcnKyjI+n8/ceOONpqmpyfVTFwEAdurq/bhjjDFJbORZiUQi8vv9CofDvD4FAH1QV+/H+8TZfQCAgYlIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgrW5Fqry8XNOnT9eQIUOUl5enq6++WgcPHkwY09raqrKyMuXm5iorK0tz5sxRfX19wpja2lqVlpZq8ODBysvL0+23366Ojo5zvzUAgH6lW5GqqqpSWVmZtm7dqoqKCrW3t2v27NlqaWmJj7ntttv03HPPafXq1aqqqtKRI0d0zTXXxPd3dnaqtLRUbW1t2rJlix5//HE99thjWr58uXu3CgDQP5hz0NDQYCSZqqoqY4wxjY2NJi0tzaxevTo+5rXXXjOSTHV1tTHGmOeff954PB4TCoXiY1auXGl8Pp+JRqNd+r3hcNhIMuFw+FyWDwBIkq7ej5/Ta1LhcFiSlJOTI0mqqalRe3u7iouL42PGjRunoqIiVVdXS5Kqq6s1adIk5efnx8eUlJQoEolo//79Z/w90WhUkUgk4QIA6P/OOlKxWExLlizRRRddpIkTJ0qSQqGQ0tPTlZ2dnTA2Pz9foVAoPubDgTq9//S+MykvL5ff749fCgsLz3bZAIA+5KwjVVZWpn379mnVqlVurueMli1bpnA4HL/U1dX1+O8EACRf6tlcadGiRVq3bp02b96sESNGxLcHAgG1tbWpsbEx4dFUfX29AoFAfMz27dsT5jt99t/pMR/l9Xrl9XrPZqkAgD6sW4+kjDFatGiR1qxZow0bNmj06NEJ+6dOnaq0tDRVVlbGtx08eFC1tbUKBoOSpGAwqL1796qhoSE+pqKiQj6fTxMmTDiX2wIA6Ge69UiqrKxMTzzxhJ555hkNGTIk/hqS3+9XRkaG/H6/FixYoKVLlyonJ0c+n0+33nqrgsGgZs6cKUmaPXu2JkyYoHnz5um+++5TKBTSnXfeqbKyMh4tAQASOMYY0+XBjnPG7Y8++qj+4R/+QdIHb+b94Q9/qCeffFLRaFQlJSV6+OGHE57KO3z4sBYuXKhNmzYpMzNT8+fP14oVK5Sa2rVmRiIR+f1+hcNh+Xy+ri4fAGCJrt6PdytStiBSANC3dfV+nM/uAwBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1kpN9gIA/B9jzCfucxynF1cC2IFIARYwnZ3qaGpSZNcuNe7Yoda6OnWeOqVUn0+ZY8Zo6Ne+psFf+IJSMjOJFQYUIgUkWSwaVePWrap/7jm9f+iQ9KFHU+3HjunUX/6iExs3yn/BBcq7+mpljR9PqDBgECkgiYwxOvbiiwqtXq2O
xsZPHtfWpsatW9V69KiKbr5ZWRMnEioMCJw4ASSJ6ezUiZde0pHf/e5TA/VhrYcPq/a//kvNr732qa9fAf0FkQKSpOWNNxRavVqx999P2P5eS4vW1dXpybfe0ktHjqilvT1hf+vhwzr65JPqbG7uzeUCScHTfUASxNrbFd65U9FQKL7NGKO3m5t19+7deqe5Wa2dnfKlpWni0KH6j+nTleb5v39TNr3yik6+/LKGX3opT/uhX+ORFJAE7SdOqP6Pf0zY9lZzs276n//Ra+GwTnV2ykgKt7frfxoatHjbNp1obU0YH3r6aUWPHOnFVQO9j0gBSWCMkensTNj2s/37Ff7IU3unbT9+XBUfCVL7iRN697HH1NHU1GPrBJKNSAF9WNOrrypcUyMTiyV7KUCPIFJAHxY7dUrvPPCAWuvqkr0UoEcQKcASpYWFSvuEkyBGZWVpck7Oma/Y2amjTz+tzmi0B1cHJAeRApIgze9Xzte/nrCtpKBAd3/lKxqUkhL/i5niOMr1evX/pk/XhOzsT5yvsbpaDWvX8rQf+h1OQQeSwJORoaHBoMI7d6rzf098cBxHJQUFGjF4sNa9+65OtLZqVFaWvjN6tHK93k+dz3R06MSmTRoyebIyx43jtHT0G0QKSALHceQZNEie9HR1fmT7xKFDNXHo0G7PGX3vPb33299qzPLlSvmMqAF9BU/3AUmSlpurVL/f1Tmb9+7V2//xHzzth36DSAFJMmjECOVecomcVHef0Hj/0CE17dvn6pxAshApIEkcx1HeVVcpt7jY1XnbT57Uu7/8pU5xWjr6ASIFJNnwyy9XSmamq3OeeucdNW7Z8rFPtQD6GiIFJJHjOBo8apRG/uAHSvX5XJ37yJNPKrxjh6tzAr2NSAEW8F9wgTJGjXJ30lhMR1atUvvJk+7OC/QiIgVYwOP1avTttys9EHB13lPvvKP6Z55RrK3N1XmB3kKkAEuk+nwqvOkmdyeNxVS/Zo3CO3fyTb7ok4gUYAnHcZQ1frxyLr7Y9bmPPv20zCd8DQhgMyIFWCQ1K0sjvvc9DTn/fFfnPfXWW3rr3/+dp/3Q5xApwDJpfr8K5s6Vk5Li6rwthw598N1TPO2HPoRIARbKHDNGgW9/W3Lxg2I7Ghv13qOPKnr0KKFCn0GkABt5PBp+6aXyX3ihq9NGQyHVP/MMb/JFn0GkAAs5jqO0oUM1/LLLlPop3yN1Nk689JJObtrk6pxATyFSgMV855+vwLe+5erTfqa9Xcf+9Cedqq3laT9Yj0gBFnM8Hg0rLtbgz3/e1Xnf/8tfFPrDHxQ7dcrVeQG3ESnAcp5Bg1S0aJG8n/ucq/Oe3LhRR554gkdTsBqRAiznOI4Gjxwp/7Rpksfdv7InN21SNBRydU7ATUQK6AOc1FSN+O53lT1zpqvzdkQiOvzznyva0ODqvIBbiBTQRziOo/Ouu871eZv37dNfX36Z09JhJSIF9CEZhYUquuUWOWlprs4b+v3v1fruu67OCbiBSAF9iJOSopxLLpF/+nRX5+1sbtab//qv6mhpcXVe4FwRKaCPSRk8WPlXXaWUIUNcnbe9sVEnKis52w9WIVJAH5Q1frxGlpW5Oqdpa9OR3/xGjVu3EipYg0gBfdSQSZOU9eUvuzpnLBrV8RdeUCwadXVe4GwRKaCPSh0yREW33KKML3zB1Xkju3bpvccf59EUrECkgD5sUFGRCv7+7+UZNMjVef+6ZYua9+8nVEg6IgX0YY7jyD9tmoZ+7Wuuztvx17/qyG9/q7b6elfnBbqLSAH9wHnXXafM8eNdnbP5wAGdqKriTb5Iqm5FauXKlZo8ebJ8Pp98Pp+CwaDWr18f39/a2qqysjLl5uYqKytLc+bMUf1H/iVWW1ur0tJSDR48WHl5ebr99tvV0dHhzq0BBiDHcZQ+fLiGX365UgYPdnXu0OrVaj5wwNU5ge7oVqRGjBihFStWqKamRjt37tQ3vvENXXXVVdq/f78k6bbb
btNzzz2n1atXq6qqSkeOHNE111wTv35nZ6dKS0vV1tamLVu26PHHH9djjz2m5cuXu3urgAHGcRzlXHSRhl95pZSS4tq8pq1Ntb/4hdqOH3dtTqA7HHOOr4zm5OTo/vvv17XXXqvhw4friSee0LXXXitJev311zV+/HhVV1dr5syZWr9+va644godOXJE+fn5kqRHHnlEP/rRj3Ts2DGlp6d36XdGIhH5/X6Fw2H5fL5zWT7Qr7RHIjrw/e+rIxJxbU4nLU2BOXMUuPZaebr4dxT4LF29Hz/r16Q6Ozu1atUqtbS0KBgMqqamRu3t7SouLo6PGTdunIqKilRdXS1Jqq6u1qRJk+KBkqSSkhJFIpH4o7EziUajikQiCRcAH5c6ZIjG3H23Uv1+1+Y07e06+vTTCu/cydl+6HXdjtTevXuVlZUlr9erW265RWvWrNGECRMUCoWUnp6u7OzshPH5+fkK/e/31YRCoYRAnd5/et8nKS8vl9/vj18KCwu7u2xgQHAcRxkjRyrn4ovdnTgWU+gPf1BHOOzuvMBn6HakvvSlL2nPnj3atm2bFi5cqPnz5+tAD7+wumzZMoXD4filrq6uR38f0Jd50tMVmDNHWRMmuDrv+2++qXd//WvF2ttdnRf4NN2OVHp6usaMGaOpU6eqvLxcU6ZM0QMPPKBAIKC2tjY1NjYmjK+vr1cgEJAkBQKBj53td/rn02POxOv1xs8oPH0B8MnShg5V4c03y0lNdW9SY3Ry0yaFVq/mtHT0mnN+n1QsFlM0GtXUqVOVlpamysrK+L6DBw+qtrZWwWBQkhQMBrV37141fOhbQCsqKuTz+TTB5X/1AQNdxsiRGl5a6vq8jdXVajt2zPV5gTPp1j+zli1bpssuu0xFRUVqamrSE088oU2bNumFF16Q3+/XggULtHTpUuXk5Mjn8+nWW29VMBjUzP/9yuvZs2drwoQJmjdvnu677z6FQiHdeeedKisrk9fr7ZEbCAxUTkqKAnPmKBoKKbxtm2vznjp8WG/+279p/E9+Io/LX74IfFS3ItXQ0KAbbrhBR48eld/v1+TJk/XCCy/ob//2byVJP/3pT+XxeDRnzhxFo1GVlJTo4Ycfjl8/JSVF69at08KFCxUMBpWZman58+fr3nvvdfdWAZAkpfr9yr34YjUfOKDOpibX5o0eOaK/vvyycr/+ddfmBM7knN8nlQy8Twronvpnn9W7v/qV5OJf97RhwzRq8WL5pkxxbU4MHD3+PikAfUfurFlKGzrU1Tnbjx/XyY0b1fn++67OC3wYkQIGgNTMTI256y55P+Us2rNxYsMGHX/xRd7kix5DpIABImPkSPmmTnV93qNPPaXokSOECj2CSAEDhJOaqs/Nm6chkye7Om9nS4ve+81v1NnS4uq8gESkgAElZfBgjfzBDzRoxAhX523cskUN69YpxtfuwGVEChhg0ocN65E3+R5/8UW1HzvG035wFZECBhjH41HuN76h4ZdfLnncuwtoP35ch+69V7FTp1ybEyBSwACUkpGh/Kuvdv209LZjx3S8okKGp/3gEiIFDFDeQECFCxZIjuPanKatTaE//EGRvXt52g+uIFLAAOafPl3Dr7jC1Tk7GhvVsHatOlz8GCYMXEQKGMA8Xq+GFRcrY9QoV+eN7N6thrVrXZ0TAxORAga4jFGjVDBvnjwZGa7OW//sswqtWcPTfjgnRAoY4BzHkX/aNPlcfpOvaWtTeMcOtX3o++OA7iJSAOQ4jorKypQ5bpyr8zbv26d3f/UrxdraXJ0XAweRAiBJSvX5lPv1r7v7lfOSGrdtU8sbb7g6JwYOIgVA0gdv8h126aXKv+YaKSXFvYmN0TsPPqhThw+7NycGDCIFIM5xHOV/85tKGTTI1XnbQiE1PPecOltbXZ0X/R+RApAgZcgQff5HP5LH63V13pObN6vplVc42w/dQqQAJHAcR1lf/rLyrrzS1Xljra16a8UKvX/okKvzon8jUgA+
xpOWppxLLlHGyJGuzms6O3X0qaf4NAp0GZECcEYZRUUatXSpUv1+V+eN7Nql4y+8IBOLuTov+iciBeATZYwapfO+/W1X5zSdnWrcvl1tx4+7Oi/6JyIF4BM5jqOhf/M38k+f7uq8La+/rvaTJ12dE/0TkQLwqVL9fp133XXynndespeCAYhIAfhUjuMoc+xYDZs929Vv8gW6gj9xALok74orlP3VryZ7GRhgiBSALnHS05V/1VWufOV8qs/n+puF0T8RKQBd4jiOMr/4RRXccINSMjPPaa6cSy7RoBEjXFoZ+jMiBaDLHMfR0K9+Vd6CgrOeY9CIEcq95BJ50tJcXBn6KyIFoFtSMjI09p57NPRrX5Pvggu69YnpqUOHquCGG5TxhS/04ArRnxApAN2WmpWlz//TP6nwe9/r8nuo0vPy9Ll585R94YVyHKeHV4j+wt1vNwMwoAwaMUJF//iPOj5ypI6/9JI6wmGZzk7p9CedezzypKUp68tfVuBb31LmF78oh9PY0Q2O6YOfmx+JROT3+xUOh+Xz+ZK9HGBAO30X0nb8uCI7d6rlzTfVEYnIk54ub0GBhkyerKzx4+WkpPAICnFdvR/nkRSAc3I6PN7hwzX8sss0PMnrQf/C424AgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBY65witWLFCjmOoyVLlsS3tba2qqysTLm5ucrKytKcOXNUX1+fcL3a2lqVlpZq8ODBysvL0+23366Ojo5zWQoAoB8660jt2LFDv/jFLzR58uSE7bfddpuee+45rV69WlVVVTpy5Iiuueaa+P7Ozk6Vlpaqra1NW7Zs0eOPP67HHntMy5cvP/tbAQDon8xZaGpqMmPHjjUVFRXm4osvNosXLzbGGNPY2GjS0tLM6tWr42Nfe+01I8lUV1cbY4x5/vnnjcfjMaFQKD5m5cqVxufzmWg02qXfHw6HjSQTDofPZvkAgCTr6v34WT2SKisrU2lpqYqLixO219TUqL29PWH7uHHjVFRUpOrqaklSdXW1Jk2apPz8/PiYkpISRSIR7d+//4y/LxqNKhKJJFwAAP1fanevsGrVKu3atUs7duz42L5QKKT09HRlZ2cnbM/Pz1coFIqP+XCgTu8/ve9MysvLdc8993R3qQCAPq5bj6Tq6uq0ePFi/e53v9OgQYN6ak0fs2zZMoXD4filrq6u1343ACB5uhWpmpoaNTQ06IILLlBqaqpSU1NVVVWlBx98UKmpqcrPz1dbW5saGxsTrldfX69AICBJCgQCHzvb7/TPp8d8lNfrlc/nS7gAAPq/bkVq1qxZ2rt3r/bs2RO/TJs2TXPnzo3/d1pamiorK+PXOXjwoGpraxUMBiVJwWBQe/fuVUNDQ3xMRUWFfD6fJkyY4NLNAgD0B916TWrIkCGaOHFiwrbMzEzl5ubGty9YsEBLly5VTk6OfD6fbr31VgWDQc2cOVOSNHv2bE2YMEHz5s3Tfffdp1AopDvvvFNlZWXyer0u3SwAQH/Q7RMnPstPf/pTeTwezZkzR9FoVCUlJXr44Yfj+1NSUrRu3TotXLhQwWBQmZmZmj9/vu699163lwIA6OMcY4xJ9iK6KxKJyO/3KxwO8/oUAPRBXb0f57P7AADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIA
AGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWSk32As6GMUaSFIlEkrwSAMDZOH3/ffr+/JP0yUidOHFCklRYWJjklQAAzkVTU5P8fv8n7u+TkcrJyZEk1dbWfuqNG+gikYgKCwtVV1cnn8+X7OVYi+PUNRynruE4dY0xRk1NTSooKPjUcX0yUh7PBy+l+f1+/hB0gc/n4zh1AcepazhOXcNx+mxdeZDBiRMAAGsRKQCAtfpkpLxer+6++255vd5kL8VqHKeu4Th1DcepazhO7nLMZ53/BwBAkvTJR1IAgIGBSAEArEWkAADWIlIAAGv1yUg99NBDGjVqlAYNGqQZM2Zo+/btyV5Sr9q8ebOuvPJKFRQUyHEcrV27NmG/MUbLly/Xeeedp4yMDBUXF+vQoUMJY06ePKm5c+fK5/MpOztbCxYsUHNzcy/eip5VXl6u6dOna8iQIcrLy9PVV1+tgwcPJoxpbW1VWVmZcnNzlZWVpTlz5qi+vj5hTG1trUpLSzV48GDl5eXp9ttvV0dHR2/elB61cuVKTZ48Of7G02AwqPXr18f3c4zObMWKFXIcR0uWLIlv41j1ENPHrFq1yqSnp5tf//rXZv/+/eamm24y2dnZpr6+PtlL6zXPP/+8+Zd/+Rfzxz/+0Ugya9asSdi/YsUK4/f7zdq1a80rr7xivvnNb5rRo0ebU6dOxcdceumlZsqUKWbr1q3mv//7v82YMWPM9ddf38u3pOeUlJSYRx991Ozbt8/s2bPHXH755aaoqMg0NzfHx9xyyy2msLDQVFZWmp07d5qZM2ear371q/H9HR0dZuLEiaa4uNjs3r3bPP/882bYsGFm2bJlybhJPeLZZ581f/rTn8wbb7xhDh48aP75n//ZpKWlmX379hljOEZnsn37djNq1CgzefJks3jx4vh2jlXP6HORuvDCC01ZWVn8587OTlNQUGDKy8uTuKrk+WikYrGYCQQC5v77749va2xsNF6v1zz55JPGGGMOHDhgJJkdO3bEx6xfv944jmPee++9Xlt7b2poaDCSTFVVlTHmg2OSlpZmVq9eHR/z2muvGUmmurraGPPBPwY8Ho8JhULxMStXrjQ+n89Eo9HevQG9aOjQoeaXv/wlx+gMmpqazNixY01FRYW5+OKL45HiWPWcPvV0X1tbm2pqalRcXBzf5vF4VFxcrOrq6iSuzB5vv/22QqFQwjHy+/2aMWNG/BhVV1crOztb06ZNi48pLi6Wx+PRtm3ben3NvSEcDkv6vw8nrqmpUXt7e8JxGjdunIqKihKO06RJk5Sfnx8fU1JSokgkov379/fi6ntHZ2enVq1apZaWFgWDQY7RGZSVlam0tDThmEj8eepJfeoDZo8fP67Ozs6E/8mSlJ+fr9dffz1Jq7JLKBSSpDMeo9P7QqGQ8vLyEvanpqYqJycnPqY/icViWrJkiS666CJNnDhR0gfHID09XdnZ2QljP3qcznQcT+/rL/bu3atgMKjW1lZlZWVpzZo1mjBhgvbs2cMx+pBVq1Zp165d2rFjx8f28eep5/SpSAFno6ysTPv27dPLL7+c7KVY6Utf+pL27NmjcDis3//+95o/f76qqqqSvSyr1NXVafHixaqoqNCgQYOSvZwBpU893Tds2DClpKR87IyZ+vp6BQKBJK3KLqePw6cdo0AgoIaGhoT9HR0dOnnyZL87josWLdK6deu0ceNGjRgxIr49EAiora1NjY2NCeM/epzOdBxP7+sv0tPTNWbMGE2dOlXl5eWaMmWKHnjgAY7Rh9TU1KihoUEXXHCBUlNTlZqaqqqqKj344INKTU1Vfn4+x6qH
9KlIpaena+rUqaqsrIxvi8ViqqysVDAYTOLK7DF69GgFAoGEYxSJRLRt27b4MQoGg2psbFRNTU18zIYNGxSLxTRjxoxeX3NPMMZo0aJFWrNmjTZs2KDRo0cn7J86darS0tISjtPBgwdVW1ubcJz27t2bEPSKigr5fD5NmDChd25IEsRiMUWjUY7Rh8yaNUt79+7Vnj174pdp06Zp7ty58f/mWPWQZJ+50V2rVq0yXq/XPPbYY+bAgQPm5ptvNtnZ2QlnzPR3TU1NZvfu3Wb37t1GkvnJT35idu/ebQ4fPmyM+eAU9OzsbPPMM8+YV1991Vx11VVnPAX9K1/5itm2bZt5+eWXzdixY/vVKegLFy40fr/fbNq0yRw9ejR+ef/99+NjbrnlFlNUVGQ2bNhgdu7caYLBoAkGg/H9p08Znj17ttmzZ4/585//bIYPH96vThm+4447TFVVlXn77bfNq6++au644w7jOI558cUXjTEco0/z4bP7jOFY9ZQ+FyljjPn5z39uioqKTHp6urnwwgvN1q1bk72kXrVx40Yj6WOX+fPnG2M+OA39rrvuMvn5+cbr9ZpZs2aZgwcPJsxx4sQJc/3115usrCzj8/nMjTfeaJqampJwa3rGmY6PJPPoo4/Gx5w6dcp8//vfN0OHDjWDBw82f/d3f2eOHj2aMM8777xjLrvsMpORkWGGDRtmfvjDH5r29vZevjU957vf/a4ZOXKkSU9PN8OHDzezZs2KB8oYjtGn+WikOFY9g6/qAABYq0+9JgUAGFiIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsNb/B14YhX5piReSAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "#Display the current game screen\n",
    "def show():\n",
    "    \"\"\"Render the env's current frame (render_mode='rgb_array') and display it.\"\"\"\n",
    "    frame = env.render()\n",
    "    plt.imshow(frame)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Sequential(\n",
       "   (0): Linear(in_features=3, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=11, bias=True)\n",
       " ),\n",
       " Sequential(\n",
       "   (0): Linear(in_features=3, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=11, bias=True)\n",
       " ))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "#Build one Q-network: 3-dim state -> 11 outputs, one per discrete action\n",
    "def build_q_net():\n",
    "    \"\"\"Return a freshly initialized MLP: Linear(3, 128) -> ReLU -> Linear(128, 11).\"\"\"\n",
    "    return torch.nn.Sequential(\n",
    "        torch.nn.Linear(3, 128),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.Linear(128, 11),\n",
    "    )\n",
    "\n",
    "\n",
    "#Model that computes the actions; this is the network actually used\n",
    "model = build_q_net()\n",
    "\n",
    "#Second network, used to score (evaluate) states\n",
    "next_model = build_q_net()\n",
    "\n",
    "#Copy model's parameters into next_model so both start identical\n",
    "next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "model, next_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(7, 0.7999999999999998)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "def get_action(state, epsilon=0.01):\n",
    "    \"\"\"Epsilon-greedy action selection using `model`.\n",
    "\n",
    "    Args:\n",
    "        state: 3-element observation (reshaped to [1, 3]).\n",
    "        epsilon: probability of replacing the greedy action with a random one.\n",
    "\n",
    "    Returns:\n",
    "        (action, action_continuous): discrete index in 0..10 and the same action\n",
    "        mapped linearly onto [-2.0, 2.0] for the environment.\n",
    "    \"\"\"\n",
    "    #Run the network to get the greedy action; no_grad because this is pure\n",
    "    #inference and no autograd graph is needed during rollouts\n",
    "    state = torch.FloatTensor(state).reshape(1, 3)\n",
    "    with torch.no_grad():\n",
    "        action = model(state).argmax().item()\n",
    "\n",
    "    #Exploration: occasionally take a uniformly random action instead\n",
    "    if random.random() < epsilon:\n",
    "        action = random.choice(range(11))\n",
    "\n",
    "    #Discrete -> continuous: index 0..10 maps linearly onto -2.0..2.0\n",
    "    action_continuous = action / 10 * 4 - 2\n",
    "\n",
    "    return action, action_continuous\n",
    "\n",
    "\n",
    "get_action([0.29292667, 0.9561349, 1.0957013])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((200, 0), 200)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Replay buffer of (state, action, reward, next_state, over) tuples\n",
    "datas = []\n",
    "\n",
    "\n",
    "#Add at least n_new transitions to the buffer, then evict the oldest beyond max_size\n",
    "def update_data(n_new=200, max_size=5000):\n",
    "    \"\"\"Play whole episodes until at least `n_new` transitions were appended to\n",
    "    `datas`, then drop the oldest ones so that at most `max_size` remain.\n",
    "\n",
    "    Returns:\n",
    "        (update_count, drop_count): number of transitions added and evicted.\n",
    "    \"\"\"\n",
    "    old_count = len(datas)\n",
    "\n",
    "    #Episodes are always played to the end, so update_count may exceed n_new\n",
    "    while len(datas) - old_count < n_new:\n",
    "        #Initialize the game\n",
    "        state = env.reset()\n",
    "\n",
    "        #Play until the episode is over\n",
    "        over = False\n",
    "        while not over:\n",
    "            #Choose an action for the current state\n",
    "            action, action_continuous = get_action(state)\n",
    "\n",
    "            #Execute the continuous action (as a 1-element list) and get feedback\n",
    "            next_state, reward, over, _ = env.step([action_continuous])\n",
    "\n",
    "            #Record the transition\n",
    "            datas.append((state, action, reward, next_state, over))\n",
    "\n",
    "            #Advance to the next state\n",
    "            state = next_state\n",
    "\n",
    "    update_count = len(datas) - old_count\n",
    "    drop_count = max(len(datas) - max_size, 0)\n",
    "\n",
    "    #Evict the oldest entries with one slice deletion (repeated pop(0) shifts the\n",
    "    #whole list every time); deleting in place keeps the same global `datas` object\n",
    "    del datas[:drop_count]\n",
    "\n",
    "    return update_count, drop_count\n",
    "\n",
    "\n",
    "update_data(), len(datas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1387/1416897299.py:7: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ../torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[-9.8579e-01, -1.6799e-01,  3.3351e-01],\n",
       "         [-9.7823e-01, -2.0754e-01,  1.4936e-01],\n",
       "         [-9.7952e-01, -2.0136e-01, -2.7487e-01],\n",
       "         [-9.3178e-01, -3.6301e-01, -2.5567e-01],\n",
       "         [-9.8903e-01, -1.4770e-01,  3.8444e-01],\n",
       "         [-9.7384e-01, -2.2724e-01,  2.2905e-01],\n",
       "         [-9.7185e-01, -2.3560e-01,  7.3057e-03],\n",
       "         [-9.9049e-01, -1.3762e-01, -3.2882e-01],\n",
       "         [-9.6954e-01, -2.4493e-01,  1.2423e-02],\n",
       "         [-9.6923e-01, -2.4616e-01,  2.3374e-01],\n",
       "         [-9.9466e-01, -1.0322e-01, -2.7013e-01],\n",
       "         [-9.9753e-01, -7.0183e-02,  1.4830e-01],\n",
       "         [-9.7093e-01, -2.3936e-01, -9.0146e-02],\n",
       "         [-9.9604e-01, -8.8886e-02, -1.3484e-01],\n",
       "         [-9.9184e-01, -1.2750e-01,  1.7671e-01],\n",
       "         [-9.9550e-01, -9.4765e-02, -3.8438e-01],\n",
       "         [-9.8710e-01, -1.6010e-01, -3.9193e-01],\n",
       "         [-9.9130e-01, -1.3166e-01, -3.2390e-01],\n",
       "         [-9.9159e-01, -1.2943e-01,  2.9866e-01],\n",
       "         [-9.5721e-01, -2.8938e-01, -6.0452e-01],\n",
       "         [-9.6585e-01, -2.5910e-01,  9.8354e-02],\n",
       "         [-9.7223e-01, -2.3404e-01,  1.9169e-01],\n",
       "         [-9.7661e-01, -2.1501e-01, -3.0765e-01],\n",
       "         [-9.9826e-01, -5.9039e-02, -7.6061e-02],\n",
       "         [-9.7300e-01, -2.3083e-01,  2.1092e-01],\n",
       "         [-9.9718e-01, -7.5053e-02, -8.0490e-02],\n",
       "         [-9.9179e-01, -1.2785e-01, -3.1977e-01],\n",
       "         [-9.6561e-01, -2.5998e-01,  8.6603e-01],\n",
       "         [-9.7999e-01, -1.9906e-01, -2.3887e-01],\n",
       "         [-9.7922e-01, -2.0282e-01,  3.0322e-01],\n",
       "         [-9.7043e-01, -2.4136e-01, -6.6721e-02],\n",
       "         [-9.9803e-01, -6.2785e-02,  7.5393e-02],\n",
       "         [-9.7534e-01, -2.2071e-01,  3.3539e-01],\n",
       "         [-9.8533e-01, -1.7067e-01,  2.2719e-01],\n",
       "         [-9.7243e-01, -2.3320e-01, -4.9395e-02],\n",
       "         [-9.7061e-01, -2.4065e-01,  1.3616e-01],\n",
       "         [-9.7018e-01, -2.4240e-01, -1.9334e-01],\n",
       "         [-9.9720e-01, -7.4736e-02,  7.5439e-03],\n",
       "         [-9.7501e-01, -2.2217e-01, -2.0374e-01],\n",
       "         [-9.9826e-01, -5.9022e-02, -3.4031e-04],\n",
       "         [-9.6995e-01, -2.4330e-01,  9.8927e-02],\n",
       "         [-9.9753e-01, -7.0265e-02, -2.1624e-01],\n",
       "         [-9.9437e-01, -1.0599e-01,  1.8868e-01],\n",
       "         [-9.7196e-01, -2.3516e-01, -1.2774e-01],\n",
       "         [-9.9629e-01, -8.6101e-02, -1.9662e-01],\n",
       "         [-9.9652e-01, -8.3383e-02,  8.3795e-01],\n",
       "         [-9.6984e-01, -2.4374e-01, -2.7342e-02],\n",
       "         [-9.6711e-01, -2.5435e-01,  1.6912e-01],\n",
       "         [-9.8063e-01, -1.9586e-01, -2.8783e-01],\n",
       "         [-9.9671e-01, -8.1046e-02, -2.7545e-01],\n",
       "         [-9.9436e-01, -1.0606e-01,  6.8931e-02],\n",
       "         [-9.7812e-01, -2.0803e-01,  2.9236e-01],\n",
       "         [-9.8683e-01, -1.6179e-01,  3.3256e-01],\n",
       "         [-9.9048e-01, -1.3764e-01, -5.5574e-01],\n",
       "         [-9.9709e-01, -7.6199e-02, -1.0386e-01],\n",
       "         [-9.7771e-01, -2.0995e-01, -2.5037e-01],\n",
       "         [-9.2529e-01, -3.7927e-01,  6.9955e-02],\n",
       "         [-9.9725e-01, -7.4154e-02, -4.1014e-02],\n",
       "         [-9.7403e-01, -2.2641e-01,  1.6766e-01],\n",
       "         [-9.9554e-01, -9.4332e-02,  1.8908e-01],\n",
       "         [-9.9232e-01, -1.2368e-01,  2.8706e-01],\n",
       "         [-9.7152e-01, -2.3695e-01, -1.1310e-01],\n",
       "         [-9.7268e-01, -2.3215e-01,  1.1785e-01],\n",
       "         [-9.9802e-01, -6.2835e-02, -1.4893e-01]]),\n",
       " tensor([[7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [8],\n",
       "         [8],\n",
       "         [7],\n",
       "         [7],\n",
       "         [8],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [8],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7]]),\n",
       " tensor([[-8.8493],\n",
       "         [-8.6026],\n",
       "         [-8.6450],\n",
       "         [-7.6806],\n",
       "         [-8.9756],\n",
       "         [-8.4877],\n",
       "         [-8.4324],\n",
       "         [-9.0327],\n",
       "         [-8.3768],\n",
       "         [-8.3748],\n",
       "         [-9.2385],\n",
       "         [-9.4360],\n",
       "         [-8.4108],\n",
       "         [-9.3208],\n",
       "         [-9.0864],\n",
       "         [-9.2985],\n",
       "         [-8.9020],\n",
       "         [-9.0685],\n",
       "         [-9.0805],\n",
       "         [-8.1491],\n",
       "         [-8.2931],\n",
       "         [-8.4454],\n",
       "         [-8.5651],\n",
       "         [-9.5031],\n",
       "         [-8.4654],\n",
       "         [-9.4045],\n",
       "         [-9.0914],\n",
       "         [-8.3619],\n",
       "         [-8.6570],\n",
       "         [-8.6379],\n",
       "         [-8.3985],\n",
       "         [-9.4800],\n",
       "         [-8.5328],\n",
       "         [-8.8272],\n",
       "         [-8.4470],\n",
       "         [-8.4041],\n",
       "         [-8.3955],\n",
       "         [-9.4058],\n",
       "         [-8.5169],\n",
       "         [-9.5027],\n",
       "         [-8.3875],\n",
       "         [-9.4380],\n",
       "         [-9.2179],\n",
       "         [-8.4367],\n",
       "         [-9.3399],\n",
       "         [-9.4229],\n",
       "         [-8.3839],\n",
       "         [-8.3234],\n",
       "         [-8.6788],\n",
       "         [-9.3746],\n",
       "         [-9.2144],\n",
       "         [-8.6060],\n",
       "         [-8.8867],\n",
       "         [-9.0534],\n",
       "         [-9.3979],\n",
       "         [-8.5922],\n",
       "         [-7.5778],\n",
       "         [-9.4096],\n",
       "         [-8.4902],\n",
       "         [-9.2892],\n",
       "         [-9.1148],\n",
       "         [-8.4257],\n",
       "         [-8.4545],\n",
       "         [-9.4814]]),\n",
       " tensor([[-9.8291e-01, -1.8411e-01,  3.2752e-01],\n",
       "         [-9.7703e-01, -2.1310e-01,  1.1370e-01],\n",
       "         [-9.8248e-01, -1.8636e-01, -3.0589e-01],\n",
       "         [-9.3899e-01, -3.4393e-01, -4.0793e-01],\n",
       "         [-9.8593e-01, -1.6714e-01,  3.9366e-01],\n",
       "         [-9.7177e-01, -2.3593e-01,  1.7862e-01],\n",
       "         [-9.7243e-01, -2.3320e-01, -4.9395e-02],\n",
       "         [-9.9251e-01, -1.2215e-01, -3.1204e-01],\n",
       "         [-9.7017e-01, -2.4244e-01, -5.1272e-02],\n",
       "         [-9.6711e-01, -2.5435e-01,  1.6912e-01],\n",
       "         [-9.9577e-01, -9.1896e-02, -2.2754e-01],\n",
       "         [-9.9672e-01, -8.0936e-02,  2.1567e-01],\n",
       "         [-9.7269e-01, -2.3209e-01, -1.4967e-01],\n",
       "         [-9.9640e-01, -8.4827e-02, -8.1501e-02],\n",
       "         [-9.9051e-01, -1.3747e-01,  2.0108e-01],\n",
       "         [-9.9671e-01, -8.1046e-02, -2.7545e-01],\n",
       "         [-9.8962e-01, -1.4369e-01, -3.3200e-01],\n",
       "         [-9.9317e-01, -1.1664e-01, -3.0265e-01],\n",
       "         [-9.8938e-01, -1.4536e-01,  3.2158e-01],\n",
       "         [-9.6600e-01, -2.5854e-01, -6.4156e-01],\n",
       "         [-9.6554e-01, -2.6026e-01,  2.4027e-02],\n",
       "         [-9.7061e-01, -2.4065e-01,  1.3616e-01],\n",
       "         [-9.8021e-01, -1.9794e-01, -3.4891e-01],\n",
       "         [-9.9826e-01, -5.9022e-02, -3.4031e-04],\n",
       "         [-9.7114e-01, -2.3850e-01,  1.5780e-01],\n",
       "         [-9.9724e-01, -7.4216e-02, -1.6780e-02],\n",
       "         [-9.9357e-01, -1.1318e-01, -2.9566e-01],\n",
       "         [-9.5458e-01, -2.9796e-01,  7.9104e-01],\n",
       "         [-9.8257e-01, -1.8590e-01, -2.6817e-01],\n",
       "         [-9.7638e-01, -2.1607e-01,  2.7110e-01],\n",
       "         [-9.7196e-01, -2.3516e-01, -1.2774e-01],\n",
       "         [-9.9753e-01, -7.0183e-02,  1.4830e-01],\n",
       "         [-9.7204e-01, -2.3482e-01,  2.8985e-01],\n",
       "         [-9.8340e-01, -1.8145e-01,  2.1919e-01],\n",
       "         [-9.7363e-01, -2.2813e-01, -1.0430e-01],\n",
       "         [-9.6969e-01, -2.4432e-01,  7.5666e-02],\n",
       "         [-9.7319e-01, -2.3001e-01, -2.5515e-01],\n",
       "         [-9.9693e-01, -7.8300e-02,  7.1492e-02],\n",
       "         [-9.7771e-01, -2.0995e-01, -2.5037e-01],\n",
       "         [-9.9803e-01, -6.2785e-02,  7.5393e-02],\n",
       "         [-9.6951e-01, -2.4506e-01,  3.6455e-02],\n",
       "         [-9.9802e-01, -6.2835e-02, -1.4893e-01],\n",
       "         [-9.9309e-01, -1.1738e-01,  2.2919e-01],\n",
       "         [-9.7408e-01, -2.2620e-01, -1.8412e-01],\n",
       "         [-9.9687e-01, -7.9066e-02, -1.4119e-01],\n",
       "         [-9.9179e-01, -1.2790e-01,  8.9541e-01],\n",
       "         [-9.7093e-01, -2.3936e-01, -9.0146e-02],\n",
       "         [-9.6585e-01, -2.5910e-01,  9.8354e-02],\n",
       "         [-9.8359e-01, -1.8040e-01, -3.1472e-01],\n",
       "         [-9.9753e-01, -7.0265e-02, -2.1624e-01],\n",
       "         [-9.9376e-01, -1.1150e-01,  1.0939e-01],\n",
       "         [-9.7538e-01, -2.2055e-01,  2.5633e-01],\n",
       "         [-9.8401e-01, -1.7811e-01,  3.3122e-01],\n",
       "         [-9.9349e-01, -1.1388e-01, -4.7897e-01],\n",
       "         [-9.9725e-01, -7.4154e-02, -4.1014e-02],\n",
       "         [-9.8063e-01, -1.9586e-01, -2.8783e-01],\n",
       "         [-9.2707e-01, -3.7490e-01, -9.4499e-02],\n",
       "         [-9.9716e-01, -7.5319e-02,  2.3371e-02],\n",
       "         [-9.7268e-01, -2.3215e-01,  1.1785e-01],\n",
       "         [-9.9435e-01, -1.0619e-01,  2.3833e-01],\n",
       "         [-9.9026e-01, -1.3926e-01,  3.1430e-01],\n",
       "         [-9.7351e-01, -2.2864e-01, -1.7082e-01],\n",
       "         [-9.7194e-01, -2.3525e-01,  6.3740e-02],\n",
       "         [-9.9826e-01, -5.9039e-02, -7.6061e-02]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0]]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#获取一批数据样本\n",
    "def get_sample():\n",
    "    #从样本池中采样\n",
    "    samples = random.sample(datas, 64)\n",
    "\n",
    "    #[b, 3]\n",
    "    state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    action = torch.LongTensor([i[1] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 1]\n",
    "    reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 3]\n",
    "    next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "state, action, reward, next_state, over"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.1995],\n",
       "        [0.1620],\n",
       "        [0.1085],\n",
       "        [0.1003],\n",
       "        [0.2113],\n",
       "        [0.1745],\n",
       "        [0.1376],\n",
       "        [0.1069],\n",
       "        [0.1374],\n",
       "        [0.1735],\n",
       "        [0.1107],\n",
       "        [0.1682],\n",
       "        [0.1230],\n",
       "        [0.1258],\n",
       "        [0.1708],\n",
       "        [0.1120],\n",
       "        [0.1072],\n",
       "        [0.1074],\n",
       "        [0.1935],\n",
       "        [0.1472],\n",
       "        [0.1489],\n",
       "        [0.1673],\n",
       "        [0.1055],\n",
       "        [0.1352],\n",
       "        [0.1710],\n",
       "        [0.1339],\n",
       "        [0.1077],\n",
       "        [0.3016],\n",
       "        [0.1115],\n",
       "        [0.1916],\n",
       "        [0.1260],\n",
       "        [0.1569],\n",
       "        [0.1967],\n",
       "        [0.1778],\n",
       "        [0.1292],\n",
       "        [0.1569],\n",
       "        [0.1135],\n",
       "        [0.1462],\n",
       "        [0.1136],\n",
       "        [0.1457],\n",
       "        [0.1507],\n",
       "        [0.1160],\n",
       "        [0.1738],\n",
       "        [0.1190],\n",
       "        [0.1178],\n",
       "        [0.3173],\n",
       "        [0.1317],\n",
       "        [0.1612],\n",
       "        [0.1076],\n",
       "        [0.1115],\n",
       "        [0.1538],\n",
       "        [0.1891],\n",
       "        [0.1995],\n",
       "        [0.1510],\n",
       "        [0.1307],\n",
       "        [0.1103],\n",
       "        [0.1302],\n",
       "        [0.1395],\n",
       "        [0.1638],\n",
       "        [0.1744],\n",
       "        [0.1914],\n",
       "        [0.1203],\n",
       "        [0.1548],\n",
       "        [0.1249]], grad_fn=<GatherBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_value(state, action):\n",
    "    #使用状态计算出动作的logits\n",
    "    #[b, 3] -> [b, 11]\n",
    "    value = model(state)\n",
    "\n",
    "    #根据实际使用的action取出每一个值\n",
    "    #这个值就是模型评估的在该状态下,执行动作的分数\n",
    "    #在执行动作前,显然并不知道会得到的反馈和next_state\n",
    "    #所以这里不能也不需要考虑next_state和reward\n",
    "    #[b, 11] -> [b, 1]\n",
    "    value = value.gather(dim=1, index=action)\n",
    "\n",
    "    return value\n",
    "\n",
    "\n",
    "get_value(state, action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-8.6556],\n",
       "        [-8.4499],\n",
       "        [-8.5408],\n",
       "        [-7.5875],\n",
       "        [-8.7675],\n",
       "        [-8.3262],\n",
       "        [-8.3058],\n",
       "        [-8.9266],\n",
       "        [-8.2513],\n",
       "        [-8.2168],\n",
       "        [-9.1265],\n",
       "        [-9.2604],\n",
       "        [-8.2957],\n",
       "        [-9.1900],\n",
       "        [-8.9151],\n",
       "        [-9.1893],\n",
       "        [-8.7976],\n",
       "        [-8.9618],\n",
       "        [-8.8867],\n",
       "        [-7.9932],\n",
       "        [-8.1584],\n",
       "        [-8.2916],\n",
       "        [-8.4645],\n",
       "        [-9.3603],\n",
       "        [-8.3078],\n",
       "        [-9.2645],\n",
       "        [-8.9843],\n",
       "        [-8.0877],\n",
       "        [-8.5498],\n",
       "        [-8.4575],\n",
       "        [-8.2819],\n",
       "        [-9.3151],\n",
       "        [-8.3504],\n",
       "        [-8.6549],\n",
       "        [-8.3273],\n",
       "        [-8.2602],\n",
       "        [-8.2883],\n",
       "        [-9.2534],\n",
       "        [-8.4088],\n",
       "        [-9.3489],\n",
       "        [-8.2494],\n",
       "        [-9.3156],\n",
       "        [-9.2179],\n",
       "        [-8.3240],\n",
       "        [-9.2170],\n",
       "        [-9.1038],\n",
       "        [-8.2634],\n",
       "        [-8.1774],\n",
       "        [-8.5752],\n",
       "        [-9.2609],\n",
       "        [-9.0577],\n",
       "        [-8.4290],\n",
       "        [-8.6919],\n",
       "        [-8.9222],\n",
       "        [-9.2612],\n",
       "        [-8.4868],\n",
       "        [-7.4692],\n",
       "        [-9.2642],\n",
       "        [-8.3384],\n",
       "        [-9.1103],\n",
       "        [-8.9223],\n",
       "        [-8.3121],\n",
       "        [-8.3116],\n",
       "        [-9.3489]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_target(reward, next_state, over):\n",
    "    #上面已经把模型认为的状态下执行动作的分数给评估出来了\n",
    "    #下面使用next_state和reward计算真实的分数\n",
    "    #针对一个状态,它到底应该多少分,可以使用以往模型积累的经验评估\n",
    "    #这也是没办法的办法,因为显然没有精确解,这里使用延迟更新的next_model评估\n",
    "\n",
    "    #使用next_state计算下一个状态的分数\n",
    "    #[b, 3] -> [b, 11]\n",
    "    with torch.no_grad():\n",
    "        target = next_model(next_state)\n",
    "    \"\"\"以下是主要的Double DQN和DQN的区别\"\"\"\n",
    "    #取所有动作中分数最大的\n",
    "    #[b, 11] -> [b]\n",
    "    #target = target.max(dim=1)[0]\n",
    "\n",
    "    #使用model计算下一个状态的分数\n",
    "    #[b, 3] -> [b, 11]\n",
    "    with torch.no_grad():\n",
    "        model_target = model(next_state)\n",
    "\n",
    "    #取分数最高的下标\n",
    "    #[b, 11] -> [b, 1]\n",
    "    model_target = model_target.max(dim=1)[1]\n",
    "    model_target = model_target.reshape(-1, 1)\n",
    "\n",
    "    #以这个下标取next_value当中的值\n",
    "    #[b, 11] -> [b]\n",
    "    target = target.gather(dim=1, index=model_target)\n",
    "    \"\"\"以上是主要的Double DQN和DQN的区别\"\"\"\n",
    "\n",
    "    #下一个状态的分数乘以一个系数,相当于权重\n",
    "    target *= 0.98\n",
    "\n",
    "    #如果next_state已经游戏结束,则next_state的分数是0\n",
    "    #因为如果下一步已经游戏结束,显然不需要再继续玩下去,也就不需要考虑next_state了.\n",
    "    #[b, 1] * [b, 1] -> [b, 1]\n",
    "    target *= (1 - over)\n",
    "\n",
    "    #加上reward就是最终的分数\n",
    "    #[b, 1] + [b, 1] -> [b, 1]\n",
    "    target += reward\n",
    "\n",
    "    return target\n",
    "\n",
    "\n",
    "get_target(reward, next_state, over)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1660.0958465940482"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    #初始化游戏\n",
    "    state = env.reset()\n",
    "\n",
    "    #记录反馈值的和,这个值越大越好\n",
    "    reward_sum = 0\n",
    "\n",
    "    #玩到游戏结束为止\n",
    "    over = False\n",
    "    while not over:\n",
    "        #根据当前状态得到一个动作\n",
    "        _, action_continuous = get_action(state)\n",
    "\n",
    "        #执行动作,得到反馈\n",
    "        state, reward, over, _ = env.step([action_continuous])\n",
    "        reward_sum += reward\n",
    "\n",
    "        #打印动画\n",
    "        if play and random.random() < 0.2:  #跳帧\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return reward_sum\n",
    "\n",
    "\n",
    "test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 400 200 0 -1471.570730773053\n",
      "20 4400 200 0 -1029.2287193828015\n",
      "40 5000 200 200 -1240.993090267291\n",
      "60 5000 200 200 -362.71653557993653\n",
      "80 5000 200 200 -752.8122188247723\n",
      "100 5000 200 200 -562.2982438485376\n",
      "120 5000 200 200 -329.80055173868806\n",
      "140 5000 200 200 -293.0991555927233\n",
      "160 5000 200 200 -939.51708431721\n",
      "180 5000 200 200 -383.6340844127264\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    #训练N次\n",
    "    for epoch in range(200):\n",
    "        #更新N条数据\n",
    "        update_count, drop_count = update_data()\n",
    "\n",
    "        #每次更新过数据后,学习N次\n",
    "        for i in range(200):\n",
    "            #采样一批数据\n",
    "            state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "            #计算一批样本的value和target\n",
    "            value = get_value(state, action)\n",
    "            target = get_target(reward, next_state, over)\n",
    "\n",
    "            #更新参数\n",
    "            loss = loss_fn(value, target)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            #把model的参数复制给next_model\n",
    "            if (i + 1) % 50 == 0:\n",
    "                next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "        if epoch % 20 == 0:\n",
    "            test_result = sum([test(play=False) for _ in range(20)]) / 20\n",
    "            print(epoch, len(datas), update_count, drop_count, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlXElEQVR4nO3df3DU9YH/8dfuJtmQhN2QQLJkSAqjnJjyQw0Cex3PO0kJNOW04o3nMJSznFYaPJDWq/QUR6czOPb7PVsrYqc/xOlU6XGKLRSsmaBBSwgQoIWIqU61ieIm/MpuAmST7L6/f1j260KQBLLZd5bnY2ZnzOfz3t3352OyTz67n911GGOMAACwkDPZEwAA4EKIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkmL1Nq1azV+/HhlZmZq5syZ2r17d7KmAgCwVFIi9etf/1orV67Uo48+qn379mnatGmqqKhQW1tbMqYDALCUIxkfMDtz5kzdeOONeuaZZyRJ0WhUxcXFuv/++/XQQw8N9XQAAJZKG+o77O7uVkNDg1atWhVb5nQ6VV5errq6uj6vEw6HFQ6HYz9Ho1GdOHFC+fn5cjgcCZ8zAGBwGWPU0dGhoqIiOZ0XflJvyCN17NgxRSIRFRYWxi0vLCzUu+++2+d11qxZo8cee2wopgcAGEItLS0aN27cBdcPeaQuxapVq7Ry5crYz8FgUCUlJWppaZHH40nizAAAlyIUCqm4uFgjR4783HFDHqnRo0fL5XKptbU1bnlra6t8Pl+f13G73XK73ect93g8RAoAhrGLvWQz5Gf3ZWRkqKysTDU1NbFl0WhUNTU18vv9Qz0dAIDFkvJ038qVK7V48WJNnz5dM2bM0A9/+EOdOnVKd999dzKmAwCwVFIideedd+ro0aNavXq1AoGArrvuOr322mvnnUwBALiyJeV9UpcrFArJ6/UqGAzymhQADEP9fRzns/sAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWGvAkdqxY4fmz5+voqIiORwOvfrqq3HrjTFavXq1xo4dqxEjRqi8vFzvvfde3JgTJ05o4cKF8ng8ys3N1ZIlS9TZ2XlZGwIASD0DjtSpU6c0bdo0rV27ts/1Tz75pJ5++mk999xzqq+vV3Z2tioqKtTV1RUbs3DhQjU2Nqq6ulpbtmzRjh07dO+99176VgAAUpO5DJLMpk2bYj9Ho1Hj8/nMD37wg9iy9vZ243a7zUsvvWSMMeadd94xksyePXtiY7Zt22YcDof5+OOP+3W/wWDQSDLBYPBypg8ASJL+Po4P6mtSH3zwgQKBgMrLy2PLvF6vZs6cqbq6OklSXV2dcnNzNX369NiY8vJyOZ1O1dfX93m74XBYoVAo7gIASH2DGqlAICBJKiwsjFteWFgYWxcIBFRQUBC3Pi0tTXl5ebEx51qzZo28Xm/sUlxcPJjTBgBYalic3bdq1SoFg8HYpaWlJdlTAgAMgUGNlM/nkyS1trbGLW9tbY2t8/l8amtri1vf29urEydOxMacy+12y+PxxF0AAKlvUCM1YcIE+Xw+1dTUxJaFQiHV19fL7/dL
kvx+v9rb29XQ0BAbs337dkWjUc2cOXMwpwMAGObSBnqFzs5Ovf/++7GfP/jgAx04cEB5eXkqKSnRihUr9P3vf18TJ07UhAkT9Mgjj6ioqEi33XabJOnaa6/V3Llzdc899+i5555TT0+Pli1bpn/9139VUVHRoG0YACAFDPS0wTfeeMNIOu+yePFiY8ynp6E/8sgjprCw0LjdbjN79mzT1NQUdxvHjx83d911l8nJyTEej8fcfffdpqOjY9BPXQQA2Km/j+MOY4xJYiMvSSgUktfrVTAY5PUpABiG+vs4PizO7gMAXJmIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWgP+gFkAidd9/LhONTUp3Nam6Jkzcrrdyhg9WllXXy332LFyOBzJniIwJIgUYAljjCKnT+tEba2O19Sou7VVkdOnZXp7JZdLrhEjlJ6Xp7ybbtLoOXOUlptLrJDyiBRgie6jR9Xy858ruGuXdO7nPkciinR2KtLZqSMvvqjgvn0qvvdeZV91VXImCwwRXpMCLBBua9PHL7zQd6DOZYxOHT6slp/+VGeam4dmgkCSECkgySJdXTq6dava6+ouHqjPOHX4sD75n/9RTzCYwNkByUWkgCQyxuj0X/6i1lde+fS1p4FdWSd37FB7XZ1MNJqYCQJJRqSAJGt95ZXLuv6x119X5MyZQZoNYBciBSSTMeo8fPiybuL0++/L9PQM0oQAuxApIIk6GxsJDPA5iBSQJCYaVdvmzYp2dSV7KoC1iBSQJCYa5YQH4CKIFJAskQiRAi6CSAFJYqJRiUgBn4tIAUliOJICLopIAUnCkRRwcUQKSBaOpICLIlJAkvB0H3BxRApIEp7uAy6OSAFJwpEUcHFECkiWSOTTC4ALIlJAkvCJE8DFESkgSUwkwmtSwEUQKSBJejs61Hv69GXfTtqoUXI4+VNGauI3G0iS8Mcfq+fo0cu+nZFf/KIcGRmDMCPAPkQKGOYcaWnJngKQMEQKGOYc6elyOBzJngaQEEQKGOYcLpdEpJCiiBQwzPF0H1IZkQKGOWd6OkdSSFlEChjuXK5kzwBIGCIFDHPOtDSOpJCyiBQwzDnS00WikKqIFDDMcXYfUhmRAoY5BydOIIURKWCYc3IKOlIYkQKGOQcnTiCFESkgCYwxg3ZbvJkXqYxIAUkyWF946HA6+ew+pCwiBSSDMTK9vcmeBWA9IgUkgzGffjMvgM9FpIBk4EgK6BciBSSBIVJAvxApIBmIFNAvRApIEiIFXByRApLBGEU5cQK4KCIFJAGvSQH9Q6SAZCBSQL8QKSAZiBTQL0QKSAYiBfQLkQKSwPCJE0C/ECkgCUwkot6Ojsu+HUd6uhwZGYMwI8BORApIgsjp0+o8ePCybyezpESZRUWDMCPATkQKGMYcTqfk5M8YqWtAv91r1qzRjTfeqJEjR6qgoEC33Xabmpqa4sZ0dXWpqqpK+fn5ysnJ0YIFC9Ta2ho3prm5WZWVlcrKylJBQYEefPBB9fIiMjBgDpdLDpcr2dMAEmZAkaqtrVVVVZV27dql6upq9fT0aM6cOTp16lRszAMPPKDNmzdr48aNqq2t1ZEjR3T77bfH1kciEVVWVqq7u1s7d+7UCy+8oPXr12v16tWDt1XAFcLhdH56NAWkKIe5jO+xPnr0qAoKClRbW6t/+Id/UDAY1JgxY/Tiiy/qjjvukCS9++67uvbaa1VXV6dZs2Zp27Zt+upXv6ojR46osLBQkvTcc8/pu9/9ro4ePaqMfrwIHAqF5PV6FQwG5fF4LnX6QNJ0ffKJGr/5zcu+nZwpUzR++XK5CwoGYVbA0Onv4/hl/RMsGAxKkvLy8iRJDQ0N6unpUXl5eWzMpEmTVFJSorq6OklSXV2dpkyZEguUJFVUVCgUCqmxsbHP+wmHwwqFQnEXAH87kuLpPqSwS45UNBrVihUr9KUvfUmTJ0+WJAUCAWVkZCg3NzdubGFh
oQKBQGzMZwN1dv3ZdX1Zs2aNvF5v7FJcXHyp0wZSisPl4sQJpLRL/u2uqqrSoUOHtGHDhsGcT59WrVqlYDAYu7S0tCT8PoFhgdekkOLSLuVKy5Yt05YtW7Rjxw6NGzcuttzn86m7u1vt7e1xR1Otra3y+XyxMbt37467vbNn/50dcy632y23230pUwVSGidOINUN6LfbGKNly5Zp06ZN2r59uyZMmBC3vqysTOnp6aqpqYkta2pqUnNzs/x+vyTJ7/fr4MGDamtri42prq6Wx+NRaWnp5WwLcMXh6T6kugEdSVVVVenFF1/Ub37zG40cOTL2GpLX69WIESPk9Xq1ZMkSrVy5Unl5efJ4PLr//vvl9/s1a9YsSdKcOXNUWlqqRYsW6cknn1QgENDDDz+sqqoqjpaAgeJICiluQJFat26dJOkf//Ef45Y///zz+rd/+zdJ0lNPPSWn06kFCxYoHA6roqJCzz77bGysy+XSli1btHTpUvn9fmVnZ2vx4sV6/PHHL29LgCsQb+ZFqhtQpPrzlqrMzEytXbtWa9euveCYL3zhC9q6detA7hpAH/hYJKQ6fruBIXYZ758/j8Pl4uk+pDR+u4EkGLQvPHQ4Pr0AKYpIAUkwmN/K6yBSSGFECkgC09OT7CkAwwKRApIgSqSAfiFSQBJwJAX0D5ECkoAjKaB/iBSQBBxJAf1DpIAkiA7i2X1AKiNSQBKY7u5kTwEYFogUkASD+T4pIJVd0vdJAbg8nz1xwhijqKRINKro3352OhzKcDp5oy6ueEQKSALT06OoMWrv7taHnZ06dPKk3guFdLSrS6d6ezXJ69V3p0xRGpHCFY5IAUnQcfq0tn70kd5ubVVnb6/GZ2frurw8+bKyNDItTd6MDLkIFECkgKFijFEkEtHhw4f1vXXr1NrSonnjxmn66NEalZGhTJdrwE/vOdL4E0Zq4zccGCLd3d166aWX9Itf/ELXnjypFWVl8mZkSLq0D4l1ZWXJO336YE8TsAqRAoZANBrVL3/5S/3qV7/Sfffdp4lbtsjZ2Xl5N+pwcCSFlMcp6ECC9fb26te//rWef/55PfbYY/qXf/kXuf92BHVZiBSuAPyGAwkUjUa1c+dO/fSnP9WqVat00003Dd6NEylcATiSAhLozJkz+slPfqKbb75Zs2fPHtz3PREpXAGIFJAgxhj97ne/08mTJ3X33XcrMzNzUG/fIclJpJDiiBSQIO3t7fr5z3+ur3/96yopKRn8T49wOORwuQb3NgHLECkgAYwx+sMf/qCenh595StfScyd8HQfrgBECkiAnp4evfHGGyovL9fIkSMTcyd9RMoYo3A4rG4+ZR0pgn+GAQnQ2tqq999/X3fccceAr/vxqVPaf+KEOnp6NCYzU/4xY5Sdnn7eOIf6/sSJ119/XWPGjNGsWbMuZeqAVYgUkADHjx+XMUY+n6/f1zHG6IPOTj26f78+7OxUVyQiT3q6Jo8apf9z441Kd57zxEcfR1K9vb16+umnNXfuXM2YMUPOc68DDDP8BgMJEAwGlZmZqaysrH6fMPGXzk7d84c/6HAwqDORiIykYE+P/tDWpuX19Tre1RV/hT4iVV9fr/r6er311ltqa2sbpK0BkodIAQkQDoeVmZkpt9vd7+v8sLFRwc98z9Rn7T52TNVHjsQvPOfsvt7eXr388svq6urS3r179eGHH8oYc0nzB2xBpIAEMMbI7XZf8Ok2Rx+vMV2Kz75P6sMPP9S+ffuUm5urUCikXbt2KRqNDsr9AMlCpIAE6enpuWAkrl69+rJv3+FwSH87kjLGqK2tTStXrtTs2bP1ne98RxkZGUQKwx6RAhIgPT1dZ86cUU8fT985LvD+psriYqVf4PWr8Tk5mpqX97n3WVZWprlz5+rUqVOaNm2aFi1apDTeR4VhjkgBCTBy5Eh1d3frzJkzfb4ulO71Ku+f/iluWUVRkR69/nplulyxP0yXw6F8t1v/98YbVZqbGze+cMGC
2H87HA653W5FIhG1tbXJ5/Np5MiRg/8pF8AQ459ZQAKMHj1avb29On78uIqLi89b7xwxQqP8fgX37lWko0PSp6GpKCrSuKwsbfnoIx3v6tL4nBzdOWGC8s85AcM9dqxG+f3nRaixsVEul0tXXXVV4jYOGEJECkiAgoICFRUVae/evbruuuvOW+9wOOS5/nqNmTdPrS+/LBOJxJZPHjVKk0eNuuBtp40apaJFi5Tm8Zy37q233tI111yj/Pz8QdsWIJl4ug9IgKysLPn9fm3btk29vb19jnG63Sr82teUd8st/f4MPldOjsbeeadyZ8w478NlT548qTfffFNf/epXeZoPKYNIAQngcDg0Z84cHTt2TDt27LjgOFdWlsZ94xsqXLBAGQUFF749l0sjxo9X8T33aMy8eXKe882+0WhUO3bsUDgc1uzZswdtO4Bk4+k+IEHGjh2r+fPna/369brhhhs0qo+n8BwOh9KyszX2jjvkmTpVJ3fuVGdjo8KBgKLhsFw5OcocN07e6dPlnT5dIy7wlR+hUEivvvqq5s6dq+zs7KHYPGBIECkgQRwOhxYvXqyFCxfqt7/9rRYuXHjBU8KdbrdyJk9W9t/9nSJnzsj09spEo3K4XHKmp8uZlXXBLziMRqPatm2b2traNH/+fLn4jimkEJ7uAxJo9OjR+ta3vqVf/vKX2rdv3+d+TJHD4ZDT7VZ6bq4yRo+Wu6BAGfn5SvN4LhgoY4wOHDigH//4x7r33nt11VVX8XoUUgqRAhLI6XTqy1/+sm666SZ9//vf11//+tdB+zw9Y4z+/Oc/69FHH9X8+fNVUVFBoJByiBSQYDk5OfrOd76jq666SitWrNC+ffsU+dsp55cqEono0KFD+va3v63S0lJVVVUpMzNzkGYM2INIAUMgOztbjz/+uK699lo99NBDevnll3Xq1KkB344xRqFQSK+88or+4z/+QzNnztR//ud/ytPHe6aAVOAww/Cz/EOhkLxer4LBIH+cGDbOBmbz5s1at26dpk6dqnvuuUdTp06NnezQ19N1Z/9Eo9Go/vjHP+qZZ57Re++9p29+85u69dZbE/f19EAC9fdxnEgBQywSieijjz7SM888o9raWk2cOFHz5s3TlClT5PV6lZmZKafTqUgkonA4rI6ODu3bt0+vvfaa/vKXv+iWW27Rv//7v2v8+PGcyYdhi0gBlotEInr33Xf12muvae/evTp+/LjS0tKUnp4uh8OhaDSqnp4e9fb2qqCgQGVlZSovL9eUKVM4QQLDXn8fx3mfFJAkLpdLX/ziF3Xttdfq5MmTOnbsmNrb23XmzBlFIhG5XC6NGDFCXq9XBQUFys3NveCXKAKpikgBSeZ0OpWfn8+HwgJ94J9lAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1hpQpNatW6epU6fK4/HI4/HI7/dr27ZtsfVdXV2qqqpSfn6+cnJytGDBArW2tsbdRnNzsyorK5WVlaWCggI9+OCD6u3tHZytAQCklAFFaty4cXriiSfU0NCgvXv36pZbbtGtt96qxsZGSdIDDzygzZs3a+PGjaqtrdWRI0d0++23x64fiURUWVmp7u5u7dy5Uy+88ILWr1+v1atXD+5WAQBSg7lMo0aNMj/72c9Me3u7SU9PNxs3boytO3z4sJFk6urqjDHGbN261TidThMIBGJj1q1bZzwejwmHw/2+z2AwaCSZYDB4udMHACRBfx/HL/k1qUgkog0bNujUqVPy+/1qaGhQT0+PysvLY2MmTZqkkpIS1dXVSZLq6uo0ZcoUFRYWxsZUVFQoFArFjsb6Eg6HFQqF4i4AgNQ34EgdPHhQOTk5crvduu+++7Rp0yaVlpYqEAgoIyNDubm5ceMLCwsVCAQkSYFAIC5QZ9efXXcha9askdfrjV2Ki4sHOm0AwDA0
4Ehdc801OnDggOrr67V06VItXrxY77zzTiLmFrNq1SoFg8HYpaWlJaH3BwCwQ9pAr5CRkaGrr75aklRWVqY9e/boRz/6ke688051d3ervb097miqtbVVPp9PkuTz+bR79+642zt79t/ZMX1xu91yu90DnSoAYJi77PdJRaNRhcNhlZWVKT09XTU1NbF1TU1Nam5ult/vlyT5/X4dPHhQbW1tsTHV1dXyeDwqLS293KkAAFLMgI6kVq1apXnz5qmkpEQdHR168cUX9eabb+r3v/+9vF6vlixZopUrVyovL08ej0f333+//H6/Zs2aJUmaM2eOSktLtWjRIj355JMKBAJ6+OGHVVVVxZESAOA8A4pUW1ubvv71r+uTTz6R1+vV1KlT9fvf/15f/vKXJUlPPfWUnE6nFixYoHA4rIqKCj377LOx67tcLm3ZskVLly6V3+9Xdna2Fi9erMcff3xwtwoAkBIcxhiT7EkMVCgUktfrVTAYlMfjSfZ0AAAD1N/HcT67DwBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1LitSTzzxhBwOh1asWBFb1tXVpaqqKuXn5ysnJ0cLFixQa2tr3PWam5tVWVmprKwsFRQU6MEHH1Rvb+/lTAUAkIIuOVJ79uzRT37yE02dOjVu+QMPPKDNmzdr48aNqq2t1ZEjR3T77bfH1kciEVVWVqq7u1s7d+7UCy+8oPXr12v16tWXvhUAgNRkLkFHR4eZOHGiqa6uNjfffLNZvny5McaY9vZ2k56ebjZu3Bgbe/jwYSPJ1NXVGWOM2bp1q3E6nSYQCMTGrFu3zng8HhMOh/t1/8Fg0EgywWDwUqYPAEiy/j6OX9KRVFVVlSorK1VeXh63vKGhQT09PXHLJ02apJKSEtXV1UmS6urqNGXKFBUWFsbGVFRUKBQKqbGxsc/7C4fDCoVCcRcAQOpLG+gVNmzYoH379mnPnj3nrQsEAsrIyFBubm7c8sLCQgUCgdiYzwbq7Pqz6/qyZs0aPfbYYwOdKgBgmBvQkVRLS4uWL1+uX/3qV8rMzEzUnM6zatUqBYPB2KWlpWXI7hsAkDwDilRDQ4Pa2tp0ww03KC0tTWlpaaqtrdXTTz+ttLQ0FRYWqru7W+3t7XHXa21tlc/nkyT5fL7zzvY7+/PZMedyu93yeDxxFwBA6htQpGbPnq2DBw/qwIEDscv06dO1cOHC2H+np6erpqYmdp2mpiY1NzfL7/dLkvx+vw4ePKi2trbYmOrqank8HpWWlg7SZgEAUsGAXpMaOXKkJk+eHLcsOztb+fn5seVLlizRypUrlZeXJ4/Ho/vvv19+v1+zZs2SJM2ZM0elpaVatGiRnnzySQUCAT388MOqqqqS2+0epM0CAKSCAZ84cTFPPfWUnE6nFixYoHA4rIqKCj377LOx9S6XS1u2bNHSpUvl9/uVnZ2txYsX6/HHHx/sqQAAhjmHMcYkexIDFQqF5PV6FQwGeX0KAIah/j6O89l9AABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQK
AGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrpSV7ApfCGCNJCoVCSZ4JAOBSnH38Pvt4fiHDMlLHjx+XJBUXFyd5JgCAy9HR0SGv13vB9cMyUnl5eZKk5ubmz924K10oFFJxcbFaWlrk8XiSPR1rsZ/6h/3UP+yn/jHGqKOjQ0VFRZ87blhGyun89KU0r9fLL0E/eDwe9lM/sJ/6h/3UP+yni+vPQQYnTgAArEWkAADWGpaRcrvdevTRR+V2u5M9Fauxn/qH/dQ/7Kf+YT8NLoe52Pl/AAAkybA8kgIAXBmIFADAWkQKAGAtIgUAsNawjNTatWs1fvx4ZWZmaubMmdq9e3eypzSkduzYofnz56uoqEgOh0Ovvvpq3HpjjFavXq2xY8dqxIgRKi8v13vvvRc35sSJE1q4cKE8Ho9yc3O1ZMkSdXZ2DuFWJNaaNWt04403auTIkSooKNBtt92mpqamuDFdXV2qqqpSfn6+cnJytGDBArW2tsaNaW5uVmVlpbKyslRQUKAHH3xQvb29Q7kpCbVu3TpNnTo19sZTv9+vbdu2xdazj/r2xBNPyOFwaMWKFbFl7KsEMcPMhg0bTEZGhvnFL35hGhsbzT333GNyc3NNa2trsqc2ZLZu3Wr+67/+y7zyyitGktm0aVPc+ieeeMJ4vV7z6quvmj/+8Y/mn//5n82ECRPMmTNnYmPmzp1rpk2bZnbt2mXeeustc/XVV5u77rpriLckcSoqKszzzz9vDh06ZA4cOGC+8pWvmJKSEtPZ2Rkbc99995ni4mJTU1Nj9u7da2bNmmX+/u//Pra+t7fXTJ482ZSXl5v9+/ebrVu3mtGjR5tVq1YlY5MS4re//a353e9+Z/785z+bpqYm873vfc+kp6ebQ4cOGWPYR33ZvXu3GT9+vJk6dapZvnx5bDn7KjGGXaRmzJhhqqqqYj9HIhFTVFRk1qxZk8RZJc+5kYpGo8bn85kf/OAHsWXt7e3G7Xabl156yRhjzDvvvGMkmT179sTGbNu2zTgcDvPxxx8P2dyHUltbm5FkamtrjTGf7pP09HSzcePG2JjDhw8bSaaurs4Y8+k/BpxOpwkEArEx69atMx6Px4TD4aHdgCE0atQo87Of/Yx91IeOjg4zceJEU11dbW6++eZYpNhXiTOsnu7r7u5WQ0ODysvLY8ucTqfKy8tVV1eXxJnZ44MPPlAgEIjbR16vVzNnzozto7q6OuXm5mr69OmxMeXl5XI6naqvrx/yOQ+FYDAo6f9/OHFDQ4N6enri9tOkSZNUUlISt5+mTJmiwsLC2JiKigqFQiE1NjYO4eyHRiQS0YYNG3Tq1Cn5/X72UR+qqqpUWVkZt08kfp8SaVh9wOyxY8cUiUTi/idLUmFhod59990kzcougUBAkvrcR2fXBQIBFRQUxK1PS0tTXl5ebEwqiUajWrFihb70pS9p8uTJkj7dBxkZGcrNzY0be+5+6ms/nl2XKg4ePCi/36+uri7l5ORo06ZNKi0t1YEDB9hHn7Fhwwbt27dPe/bsOW8dv0+JM6wiBVyKqqoqHTp0SG+//Xayp2Kla665RgcOHFAwGNT//u//avHixaqtrU32tKzS0tKi5cuXq7q6WpmZmcmezhVlWD3dN3r0aLlcrvPOmGltbZXP50vSrOxydj983j7y+Xxqa2uLW9/b26sTJ06k3H5ctmyZtmzZojfeeEPjxo2LLff5fOru7lZ7e3vc+HP3U1/7
8ey6VJGRkaGrr75aZWVlWrNmjaZNm6Yf/ehH7KPPaGhoUFtbm2644QalpaUpLS1NtbW1evrpp5WWlqbCwkL2VYIMq0hlZGSorKxMNTU1sWXRaFQ1NTXy+/1JnJk9JkyYIJ/PF7ePQqGQ6uvrY/vI7/ervb1dDQ0NsTHbt29XNBrVzJkzh3zOiWCM0bJly7Rp0yZt375dEyZMiFtfVlam9PT0uP3U1NSk5ubmuP108ODBuKBXV1fL4/GotLR0aDYkCaLRqMLhMPvoM2bPnq2DBw/qwIEDscv06dO1cOHC2H+zrxIk2WduDNSGDRuM2+0269evN++884659957TW5ubtwZM6muo6PD7N+/3+zfv99IMv/93/9t9u/fb/76178aYz49BT03N9f85je/MX/605/Mrbfe2ucp6Ndff72pr683b7/9tpk4cWJKnYK+dOlS4/V6zZtvvmk++eST2OX06dOxMffdd58pKSkx27dvN3v37jV+v9/4/f7Y+rOnDM+ZM8ccOHDAvPbaa2bMmDEpdcrwQw89ZGpra80HH3xg/vSnP5mHHnrIOBwO8/rrrxtj2Eef57Nn9xnDvkqUYRcpY4z58Y9/bEpKSkxGRoaZMWOG2bVrV7KnNKTeeOMNI+m8y+LFi40xn56G/sgjj5jCwkLjdrvN7NmzTVNTU9xtHD9+3Nx1110mJyfHeDwec/fdd5uOjo4kbE1i9LV/JJnnn38+NubMmTPmW9/6lhk1apTJysoyX/va18wnn3wSdzsffvihmTdvnhkxYoQZPXq0+fa3v216enqGeGsS5xvf+Ib5whe+YDIyMsyYMWPM7NmzY4Eyhn30ec6NFPsqMfiqDgCAtYbVa1IAgCsLkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANb6f6Vt7+hP5HB+AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "-124.75456414359867"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(play=True)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
